# This script defines the functions needed to predict plant-level production (and the resulting emissions and damages) for a given load table
library(pacman)
p_load(
  here, fst, data.table, magrittr, janitor, stringr, 
  purrr, collapse
)

# Function to get coefs nicely formatted from plant models
#
# Reads the plant-level production model estimates and returns a list with:
#   model_dt      - the wide table (one row per plant), keyed by plant_id_eia.
#                   Columns excess_load_<x>_sq are renamed to
#                   excess_load_sq_<x>, and every excess_load column is then
#                   prefixed with 'coef_'.
#   model_dt_long - one row per non-zero coefficient, with columns
#                   plant_id_eia, nerc_adj, square, coef, est_extremum.
#                   The intercept IS included, as a row with
#                   nerc_adj == 'intercept' and square == FALSE.
prep_model_table = function(){
  # Marginal production model
  model_dt = 
    fread(
      here("Data/electricity-generation/output/PPRegressors-oge.csv")
    )[,plant_id_eia := as.character(orispl_code)] |> 
    setkey(plant_id_eia)
  # Long version: rows are the non-zero coefficients. The 'intercept' column
  # is matched by the pattern as well. Nothing is stripped from its name, so
  # it becomes nerc_adj = 'intercept', square = FALSE. This lines up with the
  # intercept rows that prep_load_table() adds.
  model_dt_long = 
    melt(
      model_dt, 
      id.vars = 'plant_id_eia',
      measure.vars = patterns('excess_load_|intercept')
    )[value != 0,.(
      plant_id_eia, 
      nerc_adj = str_remove_all(variable, 'excess_load_|_sq'),
      square = str_detect(variable, '_sq$'),
      coef = value
    )] |>
    merge( # Adding the estimated excess load where max/min happens
      melt(
        model_dt,
        id.vars = 'plant_id_eia', 
        measure.vars = patterns('est_extremum')
      )[!is.na(value),.(
        plant_id_eia, 
        nerc_adj = str_remove(variable, 'est_extremum_'),
        est_extremum = value
      )],
      by = c('plant_id_eia','nerc_adj'),
      all.x = TRUE
    )
  # Squared coefficients are applied to the squared excess load. Square the
  # extremum as well, keeping its sign, so the two are comparable.
  model_dt_long[,
    est_extremum := fcase(
      square == TRUE & est_extremum >= 0, est_extremum^2, 
      square == TRUE & est_extremum < 0, -1*(est_extremum^2),
      square == FALSE, est_extremum
    )
  ]
  # Name cleaning: excess_load_<x>_sq -> excess_load_sq_<x>, then prefix every
  # excess_load column with 'coef_'. Each rename is skipped when nothing
  # matches. paste0() turns a zero-length input into "", which would give
  # setnames() a length-1 'new' for a length-0 'old' and make it fail.
  sq_cols = str_subset(colnames(model_dt), 'sq')
  if(length(sq_cols) > 0){
    setnames(
      model_dt, 
      old = sq_cols,
      new = paste0(
        'excess_load_sq_',
        sq_cols |>
          str_remove('_sq') |>
          str_remove('excess_load_')
      )
    )
  }
  excess_cols = str_subset(colnames(model_dt), 'excess_load')
  if(length(excess_cols) > 0){
    setnames(
      model_dt, 
      old = excess_cols,
      new = paste0('coef_', excess_cols)
    )
  }
  # Returning results 
  return(list(model_dt = model_dt, model_dt_long = model_dt_long))
}


# Function to get plant level table ready
#
# Builds one row per plant by successive inner joins on plant_id_eia of:
#   1. 2019 plant info (namepcap, nerc_adj, fuel_category),
#   2. the wide model table `model_dt` (supplies log_scale and intercept),
#   3. emission rates, cast wide to one column per output/pollutant pair,
#   4. marginal damages per MWh.
# Plants missing from any of these sources are dropped. The returned table
# is keyed by plant_id_eia.
prep_plant_table = function(model_dt){
  # Emission rates: long file -> one column per output x pollutant
  emissions_raw = read.fst(
    path = here('Data/electricity-generation/output/emissions-md-dt.fst'),
    as.data.table = TRUE
  )
  emissions_raw[, plant_id_eia := as.character(plant_id_eia)]
  emissions_wide = dcast(
    emissions_raw,
    plant_id_eia + net_generation_mwh_split ~ output + pollutant,
    value.var = 'emissions_rate_tons'
  )

  # Marginal damages per MWh
  damages_dt = fread(
    here('Data/electricity-generation/output/damage-per-mwh-oge-scc-148.csv')
  )
  damages_dt[, plant_id_eia := as.character(plant_id_eia)]

  # Plant characteristics for 2019
  plant_info = read.fst(
    path = here("Data/electricity-generation/plant-info-dt-oge.fst"),
    as.data.table = TRUE
  )[year == 2019, .(
    plant_id_eia = as.character(plant_id_eia),
    namepcap,
    nerc_adj,
    fuel_category
  )]

  # Left-to-right chain of inner joins. net_generation_mwh_split comes from
  # more than one source, so merge() suffixes it. The '.x' copy (left-hand
  # side of the clashing join) is the one kept below.
  merged = Reduce(
    function(lhs, rhs) merge(lhs, rhs, by = 'plant_id_eia'),
    list(plant_info, model_dt, emissions_wide, damages_dt)
  )

  plant_dt = merged[, .(
    plant_id_eia,
    nerc_adj,
    fuel_category,
    namepcap,
    log_scale,
    intercept,
    net_generation_mwh_split = net_generation_mwh_split.x,
    md_per_mwh_high,
    md_per_mwh_low,
    emissions_rate_low_co2e,
    emissions_rate_high_co2e,
    emissions_rate_low_nox,
    emissions_rate_high_nox,
    emissions_rate_low_so2,
    emissions_rate_high_so2,
    emissions_rate_low_pm25,
    emissions_rate_high_pm25
  )]
  setkey(plant_dt, plant_id_eia)
  plant_dt
}

# Function to prep load tables: add squared term and intercept
#
# Stacks three views of the load table so it can be merged with the long
# coefficient table on (nerc_adj, square):
#   - squared rows:   square = TRUE,  excess_load = excess_load^2
#   - linear rows:    square = FALSE, excess_load unchanged
#   - intercept rows: nerc_adj = 'intercept', square = FALSE, excess_load = 1,
#                     one row per distinct id
# The column named by id_var is renamed to 'id'. The caller's table is copied
# first, so it is left untouched.
prep_load_table = function(load_dt_in, id_var = 'datetime_utc'){
  loads = setnames(copy(load_dt_in), id_var, 'id')

  squared_part = loads[, .(
    id, nerc_adj, square = TRUE, excess_load = excess_load^2
  )]
  linear_part = loads[, .(
    id, nerc_adj, square = FALSE, excess_load
  )]
  intercept_part = unique(loads[, .(
    id, nerc_adj = 'intercept', square = FALSE, excess_load = 1
  )])

  rbindlist(
    list(squared_part, linear_part, intercept_part),
    use.names = TRUE
  )
}

# Function to create table of epsilons for each plant
#
# Draws n_each normal errors per plant, with mean 0 and standard deviation
# exp(log_scale) taken from that plant's row in plant_dt.
#
# plant_dt: table with plant_id_eia and log_scale columns. Assumes one row
#           per plant (the previous per-plant lookup assumed the same).
# n_each:   number of draws per plant.
# Returns a data.table (plant_id_eia, epsilon_id, epsilon) keyed by
# plant_id_eia. An empty plant_dt yields an empty (still keyed) table.
generate_epsilons = function(plant_dt, n_each = 1000){
  n_plants = nrow(plant_dt)
  # A single vectorised rnorm() call. sd is repeated per plant, so draws are
  # produced plant by plant in row order, just as one rnorm() call per plant
  # would, but without re-scanning plant_dt for every plant.
  epsilon_dt = data.table(
    plant_id_eia = rep(as.character(plant_dt$plant_id_eia), each = n_each),
    epsilon_id = rep(seq_len(n_each), times = n_plants),
    epsilon = rnorm(
      n_plants * n_each,
      mean = 0,
      sd = rep(exp(plant_dt$log_scale), each = n_each)
    )
  )
  setkey(epsilon_dt, plant_id_eia)
  return(epsilon_dt)
}

# Function to predict one plant's production for given loads
#
# plant_id_eia_in: plant to predict (character, compared with plant_id_eia).
# load_dt:         stacked load table as built by prep_load_table()
#                  (id, nerc_adj, square, excess_load, incl. intercept rows).
# plant_dt:        plant table as built by prep_plant_table().
# model_dt_long:   long coefficient table from prep_model_table().
# epsilon_dt:      error draws as built by generate_epsilons().
#
# For each load id the fitted value sum(coef * excess_load) is computed. Each
# epsilon draw is added and the result is censored to [0, namepcap].
# Emissions and damages are computed with a two-segment (low/high rate)
# schedule around net_generation_mwh_split. Everything is averaged over the
# epsilon draws. Returns one row per (plant_id_eia, id).
predict_production_plant = function(
    plant_id_eia_in, 
    load_dt, 
    plant_dt, 
    model_dt_long, 
    epsilon_dt
){
  # Attach each load row to this plant's matching coefficient(s)
  prod_dt = 
    merge(
      load_dt,
      model_dt_long[plant_id_eia == plant_id_eia_in],
      by = c('nerc_adj','square'),
      allow.cartesian = TRUE
    )
  # Intercept-only models still match through the 'intercept' rows of
  # load_dt. An empty merge therefore means none of this plant's
  # coefficients line up with the supplied loads, so fail loudly.
  if(nrow(prod_dt) == 0){
    stop('Error merging load to plant coefs')
  }
  # Per-term contribution, then summed per (plant, id). When the estimated
  # extremum is positive and the load term exceeds it, the term is held at
  # its extremum value.
  # NOTE(review): for squared rows both excess_load and est_extremum are
  # squared (est_extremum keeps its sign). A large *negative* excess load can
  # therefore trip the cap on the squared term while the linear term stays
  # uncapped. Confirm this is intended.
  coef_prod_dt = 
    prod_dt[,.(
      plant_id_eia, 
      id,
      fit_net_gen_raw = fcase(
        est_extremum <= 0 | is.na(est_extremum), coef*excess_load, 
        est_extremum > 0 & excess_load <= est_extremum, coef*excess_load,
        est_extremum > 0 & excess_load > est_extremum, coef*est_extremum,
        default = NA
      )
    )] |>
    gby(plant_id_eia, id) |>
    fsum() |>
    setkey(plant_id_eia, id)
  # Add plant attributes, then one row per epsilon draw
  gen_dt = 
    merge(
      plant_dt,
      coef_prod_dt,
      by = 'plant_id_eia'
    )|>
    merge(
      epsilon_dt, 
      by = 'plant_id_eia',
      allow.cartesian = TRUE
    )
  # Fitted value plus error, censored to [0, namepcap]
  gen_dt[,
    fit_net_generation_mwh := fcase(
      fit_net_gen_raw + epsilon >= 0 
      & fit_net_gen_raw + epsilon <= namepcap, 
      fit_net_gen_raw + epsilon,
      fit_net_gen_raw + epsilon < 0, 0,
      fit_net_gen_raw + epsilon > namepcap, namepcap
    )
  ]
  # Removing columns we don't need anymore 
  gen_dt[,c('namepcap','log_scale','intercept','epsilon_id','epsilon') := NULL]
  # Two-segment schedule: generation up to the split is charged at the low
  # rate, generation above the split at the high rate.
  split_schedule = function(gen, split, rate_high, rate_low){
    fifelse(
      gen >= split,
      (gen - split)*rate_high + split*rate_low,
      gen*rate_low
    )
  }
  # Emissions per pollutant, then damages, all with the same schedule
  for(pollutant in c('co2e', 'so2', 'nox', 'pm25')){
    set(
      gen_dt,
      j = paste0('fit_emissions_tons_', pollutant),
      value = split_schedule(
        gen_dt$fit_net_generation_mwh,
        gen_dt$net_generation_mwh_split,
        gen_dt[[paste0('emissions_rate_high_', pollutant)]],
        gen_dt[[paste0('emissions_rate_low_', pollutant)]]
      )
    )
  }
  set(
    gen_dt,
    j = 'damages',
    value = split_schedule(
      gen_dt$fit_net_generation_mwh,
      gen_dt$net_generation_mwh_split,
      gen_dt$md_per_mwh_high,
      gen_dt$md_per_mwh_low
    )
  )
  # Taking the average over the epsilon draws
  plant_gen_dt = 
    gen_dt[,.(
      plant_id_eia, 
      id,
      fit_net_gen_raw, 
      fit_net_generation_mwh, 
      fit_emissions_tons_co2e,
      fit_emissions_tons_so2,
      fit_emissions_tons_nox,
      fit_emissions_tons_pm25,
      damages
    )] |>
    gby(plant_id_eia, id) |>
    fmean()
  return(plant_gen_dt)
}


# Now running for an entire region
#
# Predicts production, emissions and damages for every plant in plant_dt
# whose nerc_adj equals nerc_adj_in. Optionally joins the region's observed
# generation data, then writes the result to
# here(<path>/<lowercase nerc_adj_in>.fst).
#
# model_tables:    list returned by prep_model_table() (model_dt_long is
#                  used).
# add_actual_data: if TRUE, inner-join the observed generation on
#                  plant_id_eia and datetime_utc.
#                  NOTE(review): that join needs id_var == 'datetime_utc'.
# id_var:          name given to the load id column in the output.
# Called for its side effect (the written file). Like before, it returns the
# value of gc().
predict_production_region = function(
  nerc_adj_in, load_dt, plant_dt, model_tables, epsilon_dt, 
  path, add_actual_data, id_var
){
  # Predicting production plant by plant, then stacking
  fit_prod_dt = 
    map(
      plant_dt[nerc_adj == nerc_adj_in]$plant_id_eia, 
      predict_production_plant, 
      load_dt = load_dt, 
      plant_dt = plant_dt, 
      model_dt_long = model_tables$model_dt_long, 
      epsilon_dt = epsilon_dt
    ) |> 
    rbindlist() |>
    setnames('id', id_var)
  if(add_actual_data == TRUE){
    # Load actual production data 
    oge_elec_gen_dt=
      read.fst(
        path = here(paste0(
          "Data/electricity-generation/elec-gen-dt/elec-gen-dt-",
          tolower(nerc_adj_in),
          ".fst"
        )),
        as.data.table = TRUE
      )[,.(
        plant_id_eia = as.character(plant_id_eia), 
        year, 
        datetime_utc, 
        net_generation_mwh,
        co2e_mass_lb_for_electricity, 
        nox_mass_lb_for_electricity, 
        so2_mass_lb_for_electricity
      )] |> 
      setkey(plant_id_eia, datetime_utc) 
    # Merging fitted and observed values
    out_dt = 
      merge(
        fit_prod_dt, 
        oge_elec_gen_dt, 
        by = c('plant_id_eia','datetime_utc')
      )
    # The observed table exists only in this branch. Remove it here; an
    # unconditional rm() below would warn when add_actual_data is FALSE.
    rm(oge_elec_gen_dt)
  } else{
    out_dt = fit_prod_dt
  }
  # Save result
  write.fst(
    x = out_dt, 
    path = here(paste0(path,'/',str_to_lower(nerc_adj_in),'.fst'))
  )
  rm(out_dt, fit_prod_dt)
  gc()
}
